r"""
Adapted from https://github.com/HobbitLong/CMC/blob/f25c37e49196a1fe7dc5f7b559ed43c6fce55f70/models/alexnet.py
"""

import os
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models.resnet import resnet18


class L2Norm(nn.Module):
    """Scale each row (dim 1) of the input to unit L2 norm.

    Uses ``F.normalize``, which divides by ``max(||x||_2, eps)``. For any row
    whose norm exceeds eps the result equals ``x / x.norm(p=2, dim=1)``. An
    all-zero row maps to zeros instead of NaN (a plain division would be 0/0).
    """

    def forward(self, x):
        return F.normalize(x, p=2, dim=1)

# Width of SmallAlexNet's hidden fully-connected layers (fc6/fc7). It can be
# overridden through the ENCODER_FEAT_SIZE environment variable, which is read
# once at import time and defaults to "4096".
feat_before_dim = int(os.environ.get("ENCODER_FEAT_SIZE", "4096"))

class SmallAlexNet(nn.Module):
    """AlexNet-style encoder with BatchNorm and an L2-normalized output.

    The network is stored as ``self.blocks`` (an ``nn.ModuleList`` of 8
    ``nn.Sequential`` blocks): five conv blocks, then fc6, fc7 and fc8.
    fc8 projects to ``feat_dim`` and L2-normalizes along dim 1.

    fc6 expects ``192 * 7 * 7`` input features. A 64x64 input produces exactly
    that: each 3x3/pad-1 conv keeps the spatial size, and the three
    ``MaxPool2d(3, 2)`` layers shrink it 64 -> 31 -> 15 -> 7.

    Args:
        in_channel: number of input image channels.
        feat_dim: size of the final (fc8) embedding.
        hidden_dim: width of fc6/fc7. ``None`` (default) uses the module-level
            ``feat_before_dim`` (``ENCODER_FEAT_SIZE`` env var, 4096 if unset).
    """

    def __init__(self, in_channel=3, feat_dim=128, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = feat_before_dim

        blocks = [
            # conv_block_1 .. conv_block_5
            self._conv_block(in_channel, 96, pool=True),
            self._conv_block(96, 192, pool=True),
            self._conv_block(192, 384, pool=False),
            self._conv_block(384, 384, pool=False),
            self._conv_block(384, 192, pool=True),
            # fc6: 192 channels x 7 x 7 spatial for a 64x64 input (see class doc).
            nn.Sequential(
                nn.Flatten(),
                nn.Linear(192 * 7 * 7, hidden_dim, bias=False),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
            ),
            # fc7
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim, bias=False),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
            ),
            # fc8: projection followed by per-row L2 normalization.
            nn.Sequential(
                nn.Linear(hidden_dim, feat_dim),
                L2Norm(),
            ),
        ]

        self.blocks = nn.ModuleList(blocks)
        self.init_weights_()

    @staticmethod
    def _conv_block(in_ch, out_ch, pool):
        """Build Conv3x3(pad 1, no bias) -> BatchNorm2d -> ReLU [-> MaxPool2d(3, 2)]."""
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        if pool:
            layers.append(nn.MaxPool2d(3, 2))
        return nn.Sequential(*layers)

    def init_weights_(self):
        """Re-initialize all parameters in place.

        Linear/Conv2d weights are drawn from N(0, 0.02) and their biases (if
        any) are zeroed. BatchNorm affine parameters (if any) are reset to
        weight=1, bias=0.
        """
        def init(m):
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.normal_(m.weight, 0, 0.02)
                if getattr(m, 'bias', None) is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                # weight/bias are None when the BatchNorm was built with affine=False.
                if getattr(m, 'weight', None) is not None:
                    nn.init.ones_(m.weight)
                if getattr(m, 'bias', None) is not None:
                    nn.init.zeros_(m.bias)

        self.apply(init)

    def forward(self, x, *, layer_index=-1):
        """Run ``x`` through ``self.blocks[0..layer_index]`` (inclusive).

        Args:
            x: input batch.
            layer_index: index of the last block to run. Negative values count
                from the end, so ``-1`` (default) runs the full network.

        Returns:
            The output of block ``layer_index``.

        Raises:
            IndexError: if ``layer_index`` is not in
                ``[-len(self.blocks), len(self.blocks))``.
        """
        num_blocks = len(self.blocks)
        # Out-of-range indices used to be silently accepted: e.g. -(n+1)
        # returned the raw input, and indices >= n ran the whole net.
        if not -num_blocks <= layer_index < num_blocks:
            raise IndexError(
                f"layer_index {layer_index} out of range for {num_blocks} blocks")
        if layer_index < 0:
            layer_index += num_blocks
        for block in self.blocks[:layer_index + 1]:
            x = block(x)
        return x


class ResNet(nn.Module):
    """ResNet-18 encoder adapted for small images, with an L2-normalized head.

    Built on torchvision's ``resnet18``:
    - its final ``fc`` outputs ``32 * 32`` features;
    - the stem max-pool is removed;
    - ``conv1`` is replaced by a 3x3, stride-1 convolution.

    ``self.predictor`` (ReLU -> Linear -> L2Norm) maps those features to
    ``feat_dim``.

    Args:
        in_channel: number of input image channels.
        feat_dim: size of the final embedding.
    """

    def __init__(self, in_channel: int = 3, feat_dim: int = 128):
        super().__init__()
        # Local width for this encoder; shadows the module-level
        # feat_before_dim, so ENCODER_FEAT_SIZE does not affect ResNet.
        feat_before_dim = 32 * 32
        self.rn = resnet18(num_classes=feat_before_dim)

        self.rn.maxpool = nn.Identity()
        # NOTE(review): padding=2 with a 3x3 kernel grows each spatial dim by 2;
        # the common small-image stem uses padding=1. Left unchanged to keep
        # existing results reproducible; confirm whether this is intended.
        self.rn.conv1 = nn.Conv2d(in_channel, 64,
                                  kernel_size=3, stride=1, padding=2, bias=False)

        self.predictor = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Linear(feat_before_dim, feat_dim, bias=False),
            L2Norm(),
        )

    def forward(self, x, layer_index: int = -1):
        """Encode ``x``.

        Args:
            x: input image batch.
            layer_index: ``-1`` returns the ``feat_dim`` embedding produced by
                ``self.predictor``. ``-2`` returns the ReLU of the backbone's
                ``32 * 32`` output features.

        Raises:
            ValueError: if ``layer_index`` is neither -1 nor -2.
        """
        if layer_index == -1:
            return self.predictor(self.rn(x))
        if layer_index == -2:
            # Apply the same ReLU that opens self.predictor, so these features
            # are exactly what the projection head consumes.
            return F.relu(self.rn(x), inplace=True)
        raise ValueError(
            f"layer_index must be -1 or -2 for ResNet, got {layer_index}")
